{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VBNcpB9c5TxB"
      },
      "source": [
        "# MongoDB As A Toolbox For Agentic Systems\n",
        "\n",
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mongodb-developer/GenAI-Showcase/blob/main/notebooks/advanced_techniques/function_calling_mongodb_as_a_toolbox.ipynb)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6s3dlQKnRkFL",
        "outputId": "900958a6-f3b4-4122-dbdb-c5bf1d694128"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[?25l   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/361.3 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K   \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m \u001b[32m358.4/361.3 kB\u001b[0m \u001b[31m11.5 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m361.3/361.3 kB\u001b[0m \u001b[31m7.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m21.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m307.7/307.7 kB\u001b[0m \u001b[31m13.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m75.6/75.6 kB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.9/77.9 kB\u001b[0m \u001b[31m3.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m318.9/318.9 kB\u001b[0m \u001b[31m14.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h"
          ]
        }
      ],
      "source": [
        "!pip install --quiet openai pymongo"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Mj2FLQpkUKcl",
        "outputId": "aa40d1f5-3659-4672-baa9-d3619473eb42"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "OpenAI API Key: ··········\n",
            "Enter MongoDB URI: ··········\n"
          ]
        }
      ],
      "source": [
        "import getpass\n",
        "import json\n",
        "import os\n",
        "\n",
        "OPENAI_API_KEY = getpass.getpass(\"OpenAI API Key: \")\n",
        "os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY\n",
        "\n",
        "MONGO_URI = getpass.getpass(\"Enter MongoDB URI: \")\n",
        "os.environ[\"MONGO_URI\"] = MONGO_URI\n",
        "\n",
        "GPT_MODEL = \"gpt-4o\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "aNa8zRCUPz93"
      },
      "outputs": [],
      "source": [
        "import openai\n",
        "\n",
        "client = openai.OpenAI()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8-tG0aX-B3ca"
      },
      "source": [
        "## Define MongoDB Tool Decorator"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "oM0G-evHCMUv"
      },
      "outputs": [],
      "source": [
        "import pymongo\n",
        "\n",
        "# Get MongoClient\n",
        "mongo_client = pymongo.MongoClient(MONGO_URI, appname=\"showcase.tools.mongodb_toolbox\")\n",
        "\n",
        "# Get database\n",
        "db = mongo_client[\"function_calling_db\"]\n",
        "\n",
        "# Get collection\n",
        "tools_collection = db[\"tools\"]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "EJe1J0QOB_AB"
      },
      "outputs": [],
      "source": [
        "import inspect\n",
        "from functools import wraps\n",
        "from typing import get_type_hints\n",
        "\n",
        "\n",
        "def get_embedding(text, model=\"text-embedding-3-small\"):\n",
        "    text = text.replace(\"\\n\", \" \")\n",
        "    return client.embeddings.create(input=[text], model=model).data[0].embedding\n",
        "\n",
        "\n",
        "def mongodb_toolbox(collection=tools_collection):\n",
        "    def decorator(func):\n",
        "        @wraps(func)\n",
        "        def wrapper(*args, **kwargs):\n",
        "            return func(*args, **kwargs)\n",
        "\n",
        "        # Generate tool definition\n",
        "        signature = inspect.signature(func)\n",
        "        docstring = inspect.getdoc(func) or \"\"\n",
        "        type_hints = get_type_hints(func)\n",
        "\n",
        "        tool_def = {\n",
        "            \"name\": func.__name__,\n",
        "            \"description\": docstring.strip(),\n",
        "            \"parameters\": {\"type\": \"object\", \"properties\": {}, \"required\": []},\n",
        "        }\n",
        "\n",
        "        for param_name, param in signature.parameters.items():\n",
        "            if (\n",
        "                param.kind == inspect.Parameter.VAR_POSITIONAL\n",
        "                or param.kind == inspect.Parameter.VAR_KEYWORD\n",
        "            ):\n",
        "                continue\n",
        "\n",
        "            param_type = type_hints.get(param_name, type(None))\n",
        "            json_type = \"string\"  # Default to string\n",
        "            if param_type in (int, float):\n",
        "                json_type = \"number\"\n",
        "            elif param_type is bool:\n",
        "                json_type = \"boolean\"\n",
        "\n",
        "            tool_def[\"parameters\"][\"properties\"][param_name] = {\n",
        "                \"type\": json_type,\n",
        "                \"description\": f\"Parameter {param_name}\",\n",
        "            }\n",
        "\n",
        "            if param.default == inspect.Parameter.empty:\n",
        "                tool_def[\"parameters\"][\"required\"].append(param_name)\n",
        "\n",
        "        tool_def[\"parameters\"][\"additionalProperties\"] = False\n",
        "\n",
        "        # Store in MongoDB\n",
        "        vector = get_embedding(tool_def[\"description\"])\n",
        "        tool_doc = {**tool_def, \"embedding\": vector}\n",
        "        collection.update_one({\"name\": func.__name__}, {\"$set\": tool_doc}, upsert=True)\n",
        "\n",
        "        return wrapper\n",
        "\n",
        "    return decorator"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "NmFa7RrZTKyk"
      },
      "outputs": [],
      "source": [
        "def vector_search(user_query, collection):\n",
        "    \"\"\"\n",
        "    Perform a vector search in the MongoDB collection based on the user query.\n",
        "\n",
        "    Args:\n",
        "    user_query (str): The user's query string.\n",
        "    collection (MongoCollection): The MongoDB collection to search.\n",
        "\n",
        "    Returns:\n",
        "    list: A list of matching documents.\n",
        "    \"\"\"\n",
        "\n",
        "    # Generate embedding for the user query\n",
        "    query_embedding = get_embedding(user_query)\n",
        "\n",
        "    if query_embedding is None:\n",
        "        return \"Invalid query or embedding generation failed.\"\n",
        "\n",
        "    # Define the vector search pipeline\n",
        "    vector_search_stage = {\n",
        "        \"$vectorSearch\": {\n",
        "            \"index\": \"vector_index\",\n",
        "            \"queryVector\": query_embedding,\n",
        "            \"path\": \"embedding\",\n",
        "            \"numCandidates\": 150,  # Number of candidate matches to consider\n",
        "            \"limit\": 2,  # Return top 5 matches\n",
        "        }\n",
        "    }\n",
        "\n",
        "    unset_stage = {\n",
        "        \"$unset\": \"embedding\"  # Exclude the 'embedding' field from the results\n",
        "    }\n",
        "\n",
        "    pipeline = [vector_search_stage, unset_stage]\n",
        "\n",
        "    # Execute the search\n",
        "    results = collection.aggregate(pipeline)\n",
        "    return list(results)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "Uvko9J-0SCn4"
      },
      "outputs": [],
      "source": [
        "import random\n",
        "from datetime import datetime\n",
        "\n",
        "\n",
        "@mongodb_toolbox()\n",
        "def shout(statement: str) -> str:\n",
        "    \"\"\"\n",
        "    Convert a statement to uppercase letters to emulate shouting. Use this when a user wants to emphasize something strongly or when they explicitly ask to 'shout' something..\n",
        "\n",
        "    \"\"\"\n",
        "    return statement.upper()\n",
        "\n",
        "\n",
        "@mongodb_toolbox()\n",
        "def get_weather(location: str, unit: str = \"celsius\") -> str:\n",
        "    \"\"\"\n",
        "    Get the current weather for a specified location.\n",
        "    Use this when a user asks about the weather in a specific place.\n",
        "\n",
        "    :param location: The name of the city or location to get weather for.\n",
        "    :param unit: The temperature unit, either 'celsius' or 'fahrenheit'. Defaults to 'celsius'.\n",
        "    :return: A string describing the current weather.\n",
        "    \"\"\"\n",
        "    conditions = [\"sunny\", \"cloudy\", \"rainy\", \"snowy\"]\n",
        "    temperature = random.randint(-10, 35)\n",
        "\n",
        "    if unit.lower() == \"fahrenheit\":\n",
        "        temperature = (temperature * 9 / 5) + 32\n",
        "\n",
        "    condition = random.choice(conditions)\n",
        "    return f\"The weather in {location} is currently {condition} with a temperature of {temperature}°{'C' if unit.lower() == 'celsius' else 'F'}.\"\n",
        "\n",
        "\n",
        "@mongodb_toolbox()\n",
        "def get_stock_price(symbol: str) -> str:\n",
        "    \"\"\"\n",
        "    Get the current stock price for a given stock symbol.\n",
        "    Use this when a user asks about the current price of a specific stock.\n",
        "\n",
        "    :param symbol: The stock symbol to look up (e.g., 'AAPL' for Apple Inc.).\n",
        "    :return: A string with the current stock price.\n",
        "    \"\"\"\n",
        "    price = round(random.uniform(10, 1000), 2)\n",
        "    return f\"The current stock price of {symbol} is ${price}.\"\n",
        "\n",
        "\n",
        "@mongodb_toolbox()\n",
        "def get_current_time(timezone: str = \"UTC\") -> str:\n",
        "    \"\"\"\n",
        "    Get the current time for a specified timezone.\n",
        "    Use this when a user asks about the current time in a specific timezone.\n",
        "\n",
        "    :param timezone: The timezone to get the current time for. Defaults to 'UTC'.\n",
        "    :return: A string with the current time in the specified timezone.\n",
        "    \"\"\"\n",
        "    current_time = datetime.utcnow().strftime(\"%H:%M:%S\")\n",
        "    return f\"The current time in {timezone} is {current_time}.\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "uurRWM6_TpUn"
      },
      "outputs": [],
      "source": [
        "def populate_tools(search_results):\n",
        "    \"\"\"\n",
        "    Populate the tools array based on the results from the vector search.\n",
        "\n",
        "    Args:\n",
        "    search_results (list): The list of documents returned from the vector search.\n",
        "\n",
        "    Returns:\n",
        "    list: A list of tool definitions in the format required by the OpenAI API.\n",
        "    \"\"\"\n",
        "    tools = []\n",
        "    for result in search_results:\n",
        "        tool = {\n",
        "            \"type\": \"function\",\n",
        "            \"function\": {\n",
        "                \"name\": result[\"name\"],\n",
        "                \"description\": result[\"description\"],\n",
        "                \"parameters\": result[\"parameters\"],\n",
        "            },\n",
        "        }\n",
        "        tools.append(tool)\n",
        "    return tools"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "PhOBBC_CSY5E"
      },
      "outputs": [],
      "source": [
        "user_query = \"Hi, can you shout the statement: We are there\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "Pf7KJni_TidX"
      },
      "outputs": [],
      "source": [
        "tools_related_to_user_query = vector_search(user_query, tools_collection)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "BZWB-_7mTTTy"
      },
      "outputs": [],
      "source": [
        "tools = populate_tools(tools_related_to_user_query)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5e7rVZZeSpEd",
        "outputId": "7ca2c3ff-5ac8-4910-c9eb-414116b287a1"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[{'function': {'description': 'Convert a statement to uppercase letters to '\n",
            "                              'emulate shouting. Use this when a user wants to '\n",
            "                              'emphasize something strongly or when they '\n",
            "                              \"explicitly ask to 'shout' something..\",\n",
            "               'name': 'shout',\n",
            "               'parameters': {'additionalProperties': False,\n",
            "                              'properties': {'statement': {'description': 'Parameter '\n",
            "                                                                          'statement',\n",
            "                                                           'type': 'string'}},\n",
            "                              'required': ['statement'],\n",
            "                              'type': 'object'}},\n",
            "  'type': 'function'},\n",
            " {'function': {'description': 'Get the current stock price for a given stock '\n",
            "                              'symbol.\\n'\n",
            "                              'Use this when a user asks about the current '\n",
            "                              'price of a specific stock.\\n'\n",
            "                              '\\n'\n",
            "                              ':param symbol: The stock symbol to look up '\n",
            "                              \"(e.g., 'AAPL' for Apple Inc.).\\n\"\n",
            "                              ':return: A string with the current stock price.',\n",
            "               'name': 'get_stock_price',\n",
            "               'parameters': {'additionalProperties': False,\n",
            "                              'properties': {'symbol': {'description': 'Parameter '\n",
            "                                                                       'symbol',\n",
            "                                                        'type': 'string'}},\n",
            "                              'required': ['symbol'],\n",
            "                              'type': 'object'}},\n",
            "  'type': 'function'}]\n"
          ]
        }
      ],
      "source": [
        "import pprint\n",
        "\n",
        "pprint.pprint(tools)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "zOidCss5TYxk"
      },
      "outputs": [],
      "source": [
        "messages = [\n",
        "    {\n",
        "        \"role\": \"system\",\n",
        "        \"content\": \"You are a helpful customer support assistant. Use the supplied tools to assist the user.\",\n",
        "    },\n",
        "    {\"role\": \"user\", \"content\": user_query},\n",
        "]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "Om7-nFHPTY0Z"
      },
      "outputs": [],
      "source": [
        "response = openai.chat.completions.create(\n",
        "    model=GPT_MODEL,\n",
        "    messages=messages,\n",
        "    tools=tools,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6NVOHTQg8zQg",
        "outputId": "45429a62-3fad-48ae-9bbd-cf7f980c055e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "ChatCompletionMessage(content=None, refusal=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_U7kkN9X2ohVasUD5ReOaRwcG', function=Function(arguments='{\"statement\":\"We are there\"}', name='shout'), type='function')])\n"
          ]
        }
      ],
      "source": [
        "# Append the message to messages list\n",
        "response_message = response.choices[0].message\n",
        "messages.append(response_message)\n",
        "\n",
        "print(response_message)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ExlEStFq9FxM",
        "outputId": "9176635d-2ea3-4e13-f53c-3b06d2fb92e9"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Debug - Tool call received: shout\n",
            "Debug - Arguments: {\"statement\":\"We are there\"}\n",
            "WE ARE THERE! How can I assist you further?\n"
          ]
        }
      ],
      "source": [
        "# Step 2: determine if the response from the model includes a tool call.\n",
        "tool_calls = response_message.tool_calls\n",
        "if tool_calls:\n",
        "    # If true the model will return the name of the tool / function to call and the argument(s)\n",
        "    tool_call = tool_calls[0]\n",
        "    tool_call_id = tool_call.id\n",
        "    tool_function_name = tool_call.function.name\n",
        "\n",
        "    print(f\"Debug - Tool call received: {tool_function_name}\")\n",
        "    print(f\"Debug - Arguments: {tool_call.function.arguments}\")\n",
        "\n",
        "    try:\n",
        "        tool_arguments = json.loads(tool_call.function.arguments)\n",
        "        tool_query_string = tool_arguments.get(\"statement\", \"\")\n",
        "    except json.JSONDecodeError:\n",
        "        print(\n",
        "            f\"Error: Unable to parse function arguments: {tool_call.function.arguments}\"\n",
        "        )\n",
        "        tool_query_string = \"\"\n",
        "\n",
        "    # Step 3: Call the function and retrieve results. Append the results to the messages list.\n",
        "    if tool_function_name == \"shout\":\n",
        "        results = shout(tool_query_string)\n",
        "\n",
        "        messages.append(\n",
        "            {\n",
        "                \"role\": \"tool\",\n",
        "                \"tool_call_id\": tool_call_id,\n",
        "                \"name\": tool_function_name,\n",
        "                \"content\": results,\n",
        "            }\n",
        "        )\n",
        "\n",
        "        # Step 4: Invoke the chat completions API with the function response appended to the messages list\n",
        "        # Note that messages with role 'tool' must be a response to a preceding message with 'tool_calls'\n",
        "        model_response_with_function_call = client.chat.completions.create(\n",
        "            model=\"gpt-4o\",\n",
        "            messages=messages,\n",
        "        )  # get a new response from the model where it can see the function response\n",
        "        print(model_response_with_function_call.choices[0].message.content)\n",
        "    else:\n",
        "        print(f\"Error: function {tool_function_name} does not exist\")\n",
        "else:\n",
        "    # Model did not identify a function to call, result can be returned to the user\n",
        "    print(response_message.content)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "state": {}
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
